"""
Generate publication-quality tables for the academic paper.

Outputs Markdown tables to paper/tables/ that pandoc can render.
"""

import pandas as pd
import numpy as np
from pathlib import Path

# Every directory below hangs off a single project root on the /mnt/c mount.
PROJECT_ROOT = Path("/mnt/c/demographics_capital_flows/multilateral")
PROCESSED_DIR = PROJECT_ROOT / "data" / "processed"   # holds full_panel.csv
OUTPUT_DIR = PROJECT_ROOT / "output"                  # CSV tables read from OUTPUT_DIR/"tables"
PAPER_DIR = PROJECT_ROOT / "paper"
TAB_DIR = PAPER_DIR / "tables"                        # destination of the generated .md tables
# Created at import time so every generator can write into it unconditionally.
TAB_DIR.mkdir(parents=True, exist_ok=True)


def _stars(p):
    """Return the significance-star suffix for p-value *p*.

    ``'***'`` if p < 0.01, ``'**'`` if p < 0.05, ``'*'`` if p < 0.1, and
    ``''`` otherwise or when *p* is missing (None / NaN).
    """
    if pd.isna(p):
        return ''
    # Thresholds are checked from strictest to loosest; first hit wins.
    for cutoff, marker in ((0.01, '***'), (0.05, '**'), (0.1, '*')):
        if p < cutoff:
            return marker
    return ''


def _fmt(val, decimals=3):
    """Format *val* as fixed-point text with *decimals* places; '' if missing."""
    return '' if pd.isna(val) else format(val, f'.{decimals}f')


def generate_summary_statistics():
    """Table 2: summary statistics for the key regression variables.

    Reads ``full_panel.csv`` from PROCESSED_DIR and keeps years 1986-2024
    inclusive. For every listed variable that exists in the panel, it reports:

    - the number of non-missing observations;
    - the number of distinct ``iso3`` codes among those observations;
    - the mean, std. dev., min and max, each with 3 decimals.

    Variables absent from the panel are skipped. The table is written to
    ``summary_statistics.md`` in TAB_DIR.

    Returns:
        DataFrame with one row per reported variable and string columns
        Variable, N, Countries, Mean, Std. Dev., Min, Max.
    """
    panel = pd.read_csv(PROCESSED_DIR / "full_panel.csv")
    panel = panel[panel['year'].between(1986, 2024)]

    # Panel column -> display label; insertion order is the table's row order.
    labels = {
        'ca_gdp': 'Current account / GDP (%)',
        'Z_1': 'Demographic polynomial Z_1',
        'Z_2': 'Demographic polynomial Z_2',
        'Z_3': 'Demographic polynomial Z_3',
        'fiscal_bal_gdp': 'Fiscal balance / GDP (%)',
        'kaopen': 'Capital openness (KAOPEN)',
        'expected_growth': 'Expected GDP growth (%)',
        'nfa_gdp_lag': 'Lagged NFA / GDP',
        'log_rel_opw': 'Log relative output per worker',
        'life_expectancy': 'Life expectancy (years)',
        'health_exp_gdp': 'Health expenditure / GDP (%)',
        'lending_rate': 'Lending rate (%)',
        'log_lending_rate': 'Log lending rate',
        'pension_spending_gdp': 'Old-age public spending / GDP (%)',
        'pension_coverage': 'Social insurance coverage (%)',
    }
    columns = ('Variable', 'N', 'Countries', 'Mean', 'Std. Dev.', 'Min', 'Max')

    records = []
    for col, label in labels.items():
        if col not in panel.columns:
            continue
        obs = panel[col].dropna()
        # Count countries only over the rows where this variable is observed.
        countries = panel.loc[obs.index, 'iso3'].nunique()
        moments = (obs.mean(), obs.std(), obs.min(), obs.max())
        values = [label, f'{len(obs):,}', str(countries)] + [_fmt(m) for m in moments]
        records.append(dict(zip(columns, values)))

    df = pd.DataFrame(records)

    header = [
        '| Variable | N | Countries | Mean | Std. Dev. | Min | Max |',
        '|:---------|---:|----------:|-----:|----------:|----:|----:|',
    ]
    body = ['| ' + ' | '.join(rec[c] for c in columns) + ' |' for rec in records]
    (TAB_DIR / "summary_statistics.md").write_text('\n'.join(header + body))
    print("  Saved summary_statistics.md")
    return df


def generate_regression_table():
    """Table 3: main regression results (up to four models side by side).

    For each model slot, the first existing CSV among its candidate filenames
    in ``OUTPUT_DIR / "tables"`` is loaded. Slots with no file are omitted,
    which also shifts the (1), (2), ... column numbering. Each CSV must
    provide ``variable``, ``coefficient``, ``std_error`` and ``p_value``
    columns.

    Each variable produces two rows:

    - a coefficient row: 4 decimals plus significance stars;
    - a standard-error row: the SE in parentheses.

    Missing values render as empty cells. If ``model_comparison.csv`` exists
    and has a ``Model`` column, fit statistics are appended after a blank
    spacer row.

    The result is written to ``regression_results.md`` in TAB_DIR as UTF-8,
    because the statistic labels contain non-ASCII characters (R², ρ).

    Returns:
        None. Returns early after printing a message when no regression
        CSV is found.
    """
    tables_dir = OUTPUT_DIR / "tables"
    # Model slot label -> candidate filenames in preference order.
    candidates = {
        'Demo Only': ['regression_demographics_only.csv'],
        'Baseline': ['regression_baseline_plus_eba.csv',
                     'regression_baseline_demo_plus_eba.csv'],
        'Extended': ['regression_extended_plus_interactions.csv',
                     'regression_extended_plus_rates.csv'],
        'Pension': ['regression_pension_model.csv'],
    }
    models = {}
    for label, fnames in candidates.items():
        for fname in fnames:
            path = tables_dir / fname
            if path.exists():
                models[label] = pd.read_csv(path)
                break

    if not models:
        print("  No regression tables found")
        return None

    model_names = list(models)
    # Union of variables across models, in first-seen order.
    all_vars = list(dict.fromkeys(
        var for df in models.values() for var in df['variable']))

    lines = [
        '| Variable | '
        + ' | '.join(f'({i}) {m}' for i, m in enumerate(model_names, start=1))
        + ' |',
        '|:---------|' + '|'.join(['---------:'] * len(model_names)) + '|',
    ]

    for var in all_vars:
        # Escape underscores so pandoc does not treat them as emphasis markers.
        coef_cells = [var.replace('_', '\\_')]
        se_cells = ['']
        for m in model_names:
            df = models[m]
            match = df[df['variable'] == var]
            if match.empty:
                coef_cells.append('')
                se_cells.append('')
                continue
            coef = _fmt(match['coefficient'].iloc[0], 4)
            se = _fmt(match['std_error'].iloc[0], 4)
            # A missing estimate yields an empty cell rather than "nan".
            coef_cells.append(coef + _stars(match['p_value'].iloc[0]) if coef else '')
            se_cells.append(f'({se})' if se else '')
        lines.append('| ' + ' | '.join(coef_cells) + ' |')
        lines.append('| ' + ' | '.join(se_cells) + ' |')

    comp_path = tables_dir / "model_comparison.csv"
    if comp_path.exists():
        comp_df = pd.read_csv(comp_path)
        if 'Model' in comp_df.columns:
            # Missing model labels become '' so they never match (and never
            # produce a NaN mask).
            model_labels = comp_df['Model'].fillna('').astype(str)
            # NOTE(review): a model is matched by case-insensitive substring
            # of the first word of its slot label, and the first matching row
            # wins. E.g. 'Demo' would also match a label like "Baseline demo
            # + EBA" if it came first. Confirm the labels in
            # model_comparison.csv are unambiguous.
            comp_rows = {
                m: comp_df[model_labels.str.contains(
                    m.split()[0], case=False, regex=False)]
                for m in model_names
            }
            lines.append('| ' + ' | '.join([''] * (len(model_names) + 1)) + ' |')
            for stat, fmt in [('N obs', ',.0f'), ('N countries', ',.0f'),
                              ('R²', '.3f'), ('Adj R²', '.3f'), ('ρ', '.3f')]:
                cells = [stat]
                for m in model_names:
                    row = comp_rows[m]
                    if len(row) > 0 and stat in row.columns:
                        val = row[stat].values[0]
                        cells.append('' if pd.isna(val) else f'{val:{fmt}}')
                    else:
                        cells.append('')
                lines.append('| ' + ' | '.join(cells) + ' |')
        else:
            print("  model_comparison.csv has no 'Model' column; skipping fit statistics")

    lines.append('')
    lines.append('*Notes:* \\*\\*\\* p<0.01, \\*\\* p<0.05, \\* p<0.1. '
                 'Standard errors in parentheses. '
                 'Pooled GLS with AR(1) correction.')

    with open(TAB_DIR / "regression_results.md", 'w', encoding='utf-8') as fh:
        fh.write('\n'.join(lines))
    print("  Saved regression_results.md")


def generate_data_sources_table():
    """Table 1: data sources and their coverage.

    The content is static. Each row gives:

    - the source;
    - the variables it supplies;
    - its coverage;
    - an optional pandoc citation key (``@key``), or an empty cell.

    The table is written to ``data_sources.md`` in TAB_DIR.
    """
    # (source, variables, coverage, citation key or '')
    sources = (
        ('UN WPP 2024', 'Population by 5-year age group', '237 countries, 1950-2100', '@unwpp2024'),
        ('IMF WEO (Apr 2025)', 'CA/GDP, fiscal balance, GDP growth, output gap', '196 countries, 1980-2030', ''),
        ('Penn World Tables 10.0', 'Output per worker, human capital', '183 countries, 1950-2019', '@pwt100'),
        ('World Bank WDI', 'Health expenditure, life expectancy', '193 countries, 1970-2023', ''),
        ('Chinn-Ito KAOPEN', 'Financial openness index', '181 countries, 1970-2023', '@chinn2006'),
        ('Lane & Milesi-Ferretti EWN', 'Net foreign assets / GDP', '212 countries, 1970-2024', '@lane2007'),
        ('IMF MFS\\_IR', 'Lending and money market rates', '150 countries, 1970-2024', ''),
        ('FRED', 'Government bond yields, short rates', '23 OECD countries, 1970-2025', ''),
        ('OECD SOCX', 'Old-age public spending / GDP', '42 OECD countries, 1980-2024', ''),
        ('World Bank ASPIRE', 'Social insurance coverage', '118 countries, 1999-2023', ''),
    )

    table = ['| Source | Variables | Coverage | Reference |',
             '|:-------|:----------|:---------|:----------|']
    table.extend('| ' + ' | '.join(entry) + ' |' for entry in sources)

    (TAB_DIR / "data_sources.md").write_text('\n'.join(table))
    print("  Saved data_sources.md")


def generate_structural_break_table():
    """Table 4: split-sample structural-break comparison.

    Reads ``structural_break_summary.csv`` from ``OUTPUT_DIR / "tables"``.
    If the file is absent, prints a message and returns. Each CSV row becomes
    one table row, built from these columns:

    - Specification, Period, N obs, N countries, R²;
    - ``Z_k_coef`` and, optionally, ``Z_k_pval`` for k = 1..3.

    A missing ``Z_k_pval`` column means no stars. Missing numeric values
    render as empty cells instead of ``nan`` or crashing ``int()``.

    The table is written to ``structural_breaks.md`` in TAB_DIR as UTF-8,
    because the header contains the non-ASCII "R²".
    """
    path = OUTPUT_DIR / "tables" / "structural_break_summary.csv"
    if not path.exists():
        print("  No structural break summary found")
        return

    df = pd.read_csv(path)

    def _count(val, thousands=False):
        # Integer count as text; '' when missing.
        if pd.isna(val):
            return ''
        n = int(val)
        return f'{n:,}' if thousands else str(n)

    lines = ['| Specification | Period | N | Countries | R² | Z_1 | Z_2 | Z_3 |',
             '|:-------------|:-------|---:|----------:|----:|----:|----:|----:|']

    for _, row in df.iterrows():
        z_cells = []
        for k in (1, 2, 3):
            coef = _fmt(row[f'Z_{k}_coef'], 2)
            # p-value default of 1 means "no stars" if the column is absent.
            z_cells.append(coef + _stars(row.get(f'Z_{k}_pval', 1)) if coef else '')
        lines.append(
            f"| {row['Specification']} | {row['Period']} "
            f"| {_count(row['N obs'], thousands=True)} | {_count(row['N countries'])} "
            f"| {_fmt(row['R²'], 3)} | " + ' | '.join(z_cells) + ' |'
        )

    lines.append('')
    lines.append('*Notes:* \\*\\*\\* p<0.01, \\*\\* p<0.05, \\* p<0.1. '
                 'Pooled GLS with AR(1) correction. '
                 'Controls: fiscal balance, KAOPEN, expected growth, NFA/GDP, '
                 'relative output per worker, life expectancy.')

    with open(TAB_DIR / "structural_breaks.md", 'w', encoding='utf-8') as fh:
        fh.write('\n'.join(lines))
    print("  Saved structural_breaks.md")


def generate_projection_table():
    """Table 5: projected demographic contribution to CA/GDP.

    Reads ``projection_table.csv`` from ``OUTPUT_DIR / "tables"``. If the
    file is absent, prints a message and returns.

    - The identifier column is ``Country`` when present, otherwise ``iso3``.
      If neither exists, a message is printed and nothing is written.
    - Every other column is treated as a projection-year column, with values
      shown to 2 decimals.
    - Missing values, and missing identifiers, render as empty cells.

    The table is written to ``projections.md`` in TAB_DIR as UTF-8, since
    country names may contain non-ASCII characters.
    """
    path = OUTPUT_DIR / "tables" / "projection_table.csv"
    if not path.exists():
        print("  No projection table found")
        return

    df = pd.read_csv(path)

    id_col = 'Country' if 'Country' in df.columns else 'iso3'
    if id_col not in df.columns:
        print("  Projection table has no 'Country' or 'iso3' column")
        return
    year_cols = [c for c in df.columns if c != id_col]

    lines = ['| Country | ' + ' | '.join(map(str, year_cols)) + ' |',
             '|:--------|' + '|'.join(['---------:'] * len(year_cols)) + '|']

    for _, row in df.iterrows():
        ident = row[id_col]
        # str() so non-string identifiers cannot break ' | '.join.
        cells = ['' if pd.isna(ident) else str(ident)]
        cells.extend(_fmt(row[yc], 2) for yc in year_cols)  # _fmt maps NaN -> ''
        lines.append('| ' + ' | '.join(cells) + ' |')

    lines.append('')
    lines.append('*Notes:* Values are the estimated demographic contribution to CA/GDP '
                 'in percentage points, computed from baseline model coefficients '
                 'applied to UN WPP 2024 medium-variant population projections.')

    with open(TAB_DIR / "projections.md", 'w', encoding='utf-8') as fh:
        fh.write('\n'.join(lines))
    print("  Saved projections.md")


def _select_extreme_residuals(df, n=10):
    """Return the *n* most negative and *n* most positive rows by ``mean_resid``.

    Rows with a missing ``mean_resid`` are dropped first. Otherwise
    ``sort_values`` would place them last, and they would crowd out the most
    positive countries. When at most ``2 * n`` rows remain, all of them are
    returned once, so no country appears twice. The result is sorted
    ascending by ``mean_resid``.
    """
    ranked = df.dropna(subset=['mean_resid']).sort_values('mean_resid')
    if len(ranked) <= 2 * n:
        return ranked
    return pd.concat([ranked.head(n), ranked.tail(n)])


def generate_residual_table():
    """Table 6: countries with the most extreme mean residuals.

    Reads ``country_residuals.csv`` from ``OUTPUT_DIR / "tables"``, using the
    columns iso3, mean_resid, std_resid, n_obs and mean_ca. If the file is
    absent, prints a message and returns.

    The table lists the 10 most negative and 10 most positive countries by
    ``mean_resid``, each country at most once, sorted ascending. It is
    written to ``country_residuals.md`` in TAB_DIR.
    """
    path = OUTPUT_DIR / "tables" / "country_residuals.csv"
    if not path.exists():
        print("  No residual table found")
        return

    show = _select_extreme_residuals(pd.read_csv(path), n=10)

    lines = ['| Country | Mean Residual | Std. Dev. | N obs | Mean CA/GDP |',
             '|:--------|-------------:|---------:|------:|-----------:|']

    for _, row in show.iterrows():
        lines.append(
            f"| {row['iso3']} | {row['mean_resid']:+.2f} | {row['std_resid']:.2f} "
            f"| {int(row['n_obs'])} | {row['mean_ca']:.2f} |"
        )

    lines.append('')
    lines.append('*Notes:* Residuals from the baseline model (Model 2). '
                 'Positive = actual CA/GDP exceeds model prediction. '
                 'Top 10 (most negative) and bottom 10 (most positive) shown.')

    with open(TAB_DIR / "country_residuals.md", 'w', encoding='utf-8') as fh:
        fh.write('\n'.join(lines))
    print("  Saved country_residuals.md")


def generate_all_tables():
    """Run every table generator in order, then print a completion message."""
    print("\n>>> Generating publication tables <<<")
    steps = (
        generate_data_sources_table,
        generate_summary_statistics,
        generate_regression_table,
        generate_structural_break_table,
        generate_projection_table,
        generate_residual_table,
    )
    for step in steps:
        step()
    print("  All tables generated in paper/tables/")


if __name__ == "__main__":
    generate_all_tables()
